{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YVfdLsnHFrZ6"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "mSVgbpixegCb"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FZ2T25awqPPX"
      },
      "source": [
        "# 通过子类化创建新的层和模型"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0Q8vYTLTqJ1i"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://tensorflow.google.cn/guide/keras/custom_layers_and_models\" class=\"\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\" class=\"\">在 TensorFlow.org 上查看</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/guide/keras/custom_layers_and_models.ipynb\" class=\"\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\">在 Google Colab 中运行 </a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/guide/keras/custom_layers_and_models.ipynb\" class=\"\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\">在 GitHub 上查看源代码</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/zh-cn/guide/keras/custom_layers_and_models.ipynb\" class=\"\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\">下载笔记本</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kKyceCkqt8kO"
      },
      "source": [
        "## 设置"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lhsGftkq8III"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "from tensorflow import keras"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ox9WJETD27aT"
      },
      "source": [
        "## `Layer` 类：状态（权重）和部分计算的组合\n",
        "\n",
        "Keras 的一个中心抽象是 `Layer` 类。层封装了状态（层的“权重”）和从输入到输出的转换（“调用”，即层的前向传递）。\n",
        "\n",
        "下面是一个密集连接的层。它具有一个状态：变量 `w` 和 `b`。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "psLn5u4RUhru"
      },
      "outputs": [],
      "source": [
        "class Linear(keras.layers.Layer):\n",
        "    def __init__(self, units=32, input_dim=32):\n",
        "        super(Linear, self).__init__()\n",
        "        w_init = tf.random_normal_initializer()\n",
        "        self.w = tf.Variable(\n",
        "            initial_value=w_init(shape=(input_dim, units), dtype=\"float32\"),\n",
        "            trainable=True,\n",
        "        )\n",
        "        b_init = tf.zeros_initializer()\n",
        "        self.b = tf.Variable(\n",
        "            initial_value=b_init(shape=(units,), dtype=\"float32\"), trainable=True\n",
        "        )\n",
        "\n",
        "    def call(self, inputs):\n",
        "        return tf.matmul(inputs, self.w) + self.b\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xa4ZKbcgY2ws"
      },
      "source": [
        "您可以在某些张量输入上通过调用来使用层，这一点很像 Python 函数。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wCFLPZvTII3w"
      },
      "outputs": [],
      "source": [
        "x = tf.ones((2, 2))\n",
        "linear_layer = Linear(4, 2)\n",
        "y = linear_layer(x)\n",
        "print(y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VU6qtjLJ7OoU"
      },
      "source": [
        "请注意，权重 `w` 和 `b` 在被设置为层特性后会由层自动跟踪："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jWeZhh3V7wJa"
      },
      "outputs": [],
      "source": [
        "assert linear_layer.weights == [linear_layer.w, linear_layer.b]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "u8UFOXQ7wND7"
      },
      "source": [
        "请注意，您还可以使用一种更加快捷的方式为层添加权重：`add_weight()` 方法："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JrVpdNTTIVB7"
      },
      "outputs": [],
      "source": [
        "class Linear(keras.layers.Layer):\n",
        "    def __init__(self, units=32, input_dim=32):\n",
        "        super(Linear, self).__init__()\n",
        "        self.w = self.add_weight(\n",
        "            shape=(input_dim, units), initializer=\"random_normal\", trainable=True\n",
        "        )\n",
        "        self.b = self.add_weight(shape=(units,), initializer=\"zeros\", trainable=True)\n",
        "\n",
        "    def call(self, inputs):\n",
        "        return tf.matmul(inputs, self.w) + self.b\n",
        "\n",
        "\n",
        "x = tf.ones((2, 2))\n",
        "linear_layer = Linear(4, 2)\n",
        "y = linear_layer(x)\n",
        "print(y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0FWNOtdYhuJy"
      },
      "source": [
        "## 层可以具有不可训练权重\n",
        "\n",
        "除了可训练权重外，您还可以向层添加不可训练权重。对层进行训练时，不必在反向传播期间考虑此类权重。\n",
        "\n",
        "以下是添加和使用不可训练权重的方法："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6yTiS9tI2G8c"
      },
      "outputs": [],
      "source": [
        "class ComputeSum(keras.layers.Layer):\n",
        "    def __init__(self, input_dim):\n",
        "        super(ComputeSum, self).__init__()\n",
        "        self.total = tf.Variable(initial_value=tf.zeros((input_dim,)), trainable=False)\n",
        "\n",
        "    def call(self, inputs):\n",
        "        self.total.assign_add(tf.reduce_sum(inputs, axis=0))\n",
        "        return self.total\n",
        "\n",
        "\n",
        "x = tf.ones((2, 2))\n",
        "my_sum = ComputeSum(2)\n",
        "y = my_sum(x)\n",
        "print(y.numpy())\n",
        "y = my_sum(x)\n",
        "print(y.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I2jZ13RpWSJZ"
      },
      "source": [
        "它是 `layer.weights` 的一部分，但被归类为不可训练权重："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zHmQSpo4D4QQ"
      },
      "outputs": [],
      "source": [
        "print(\"weights:\", len(my_sum.weights))\n",
        "print(\"non-trainable weights:\", len(my_sum.non_trainable_weights))\n",
        "\n",
        "# It's not included in the trainable weights:\n",
        "print(\"trainable_weights:\", my_sum.trainable_weights)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wjFeyyncwHsn"
      },
      "source": [
        "## 最佳做法：将权重创建推迟到得知输入的形状之后\n",
        "\n",
        "上面的 `Linear` 层接受了一个 `input_dim` 参数，用于计算 `__init__()` 中权重 `w` 和 `b` 的形状："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N9B8k9LtrJbl"
      },
      "outputs": [],
      "source": [
        "class Linear(keras.layers.Layer):\n",
        "    def __init__(self, units=32, input_dim=32):\n",
        "        super(Linear, self).__init__()\n",
        "        self.w = self.add_weight(\n",
        "            shape=(input_dim, units), initializer=\"random_normal\", trainable=True\n",
        "        )\n",
        "        self.b = self.add_weight(shape=(units,), initializer=\"zeros\", trainable=True)\n",
        "\n",
        "    def call(self, inputs):\n",
        "        return tf.matmul(inputs, self.w) + self.b\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aWx8tcSpyO7n"
      },
      "source": [
        "在许多情况下，您可能事先不知道输入的大小，并希望在得知该值时（对层进行实例化后的某个时间）再延迟创建权重。\n",
        "\n",
        "在 Keras API 中，我们建议您在层的 `build(self, inputs_shape)` 方法中创建层的权重。如下所示："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Cow04aRb83DJ"
      },
      "outputs": [],
      "source": [
        "class Linear(keras.layers.Layer):\n",
        "    def __init__(self, units=32):\n",
        "        super(Linear, self).__init__()\n",
        "        self.units = units\n",
        "\n",
        "    def build(self, input_shape):\n",
        "        self.w = self.add_weight(\n",
        "            shape=(input_shape[-1], self.units),\n",
        "            initializer=\"random_normal\",\n",
        "            trainable=True,\n",
        "        )\n",
        "        self.b = self.add_weight(\n",
        "            shape=(self.units,), initializer=\"random_normal\", trainable=True\n",
        "        )\n",
        "\n",
        "    def call(self, inputs):\n",
        "        return tf.matmul(inputs, self.w) + self.b\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ButahocyBR7y"
      },
      "source": [
        "层的 `__call__()` 方法将在首次调用时自动运行构建。现在，您有了一个延迟并因此更易使用的层："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8E8ufQGU2r8u"
      },
      "outputs": [],
      "source": [
        "# At instantiation, we don't know on what inputs this is going to get called\n",
        "linear_layer = Linear(32)\n",
        "\n",
        "# The layer's weights are created dynamically the first time the layer is called\n",
        "y = linear_layer(x)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YYE0fkm8bIg1"
      },
      "source": [
        "## 层可递归组合\n",
        "\n",
        "如果将一个层实例分配为另一个层的特性，则外部层将开始跟踪内部层的权重。\n",
        "\n",
        "我们建议在 `__init__()` 方法中创建此类子层（由于子层通常具有构建方法，它们将与外部层同时构建）。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hJr89U0qOJqR"
      },
      "outputs": [],
      "source": [
        "# Let's assume we are reusing the Linear class\n",
        "# with a `build` method that we defined above.\n",
        "\n",
        "\n",
        "class MLPBlock(keras.layers.Layer):\n",
        "    def __init__(self):\n",
        "        super(MLPBlock, self).__init__()\n",
        "        self.linear_1 = Linear(32)\n",
        "        self.linear_2 = Linear(32)\n",
        "        self.linear_3 = Linear(1)\n",
        "\n",
        "    def call(self, inputs):\n",
        "        x = self.linear_1(inputs)\n",
        "        x = tf.nn.relu(x)\n",
        "        x = self.linear_2(x)\n",
        "        x = tf.nn.relu(x)\n",
        "        return self.linear_3(x)\n",
        "\n",
        "\n",
        "mlp = MLPBlock()\n",
        "y = mlp(tf.ones(shape=(3, 64)))  # The first call to the `mlp` will create the weights\n",
        "print(\"weights:\", len(mlp.weights))\n",
        "print(\"trainable weights:\", len(mlp.trainable_weights))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wQUEr9dZqxaN"
      },
      "source": [
        "## `add_loss()` 方法\n",
        "\n",
        "在编写层的 `call()` 方法时，您可以在编写训练循环时创建需要稍后使用的损失张量。这可以通过调用 `self.add_loss(value)` 来实现："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W99EJ41lQhOR"
      },
      "outputs": [],
      "source": [
        "# A layer that creates an activity regularization loss\n",
        "class ActivityRegularizationLayer(keras.layers.Layer):\n",
        "    def __init__(self, rate=1e-2):\n",
        "        super(ActivityRegularizationLayer, self).__init__()\n",
        "        self.rate = rate\n",
        "\n",
        "    def call(self, inputs):\n",
        "        self.add_loss(self.rate * tf.reduce_sum(inputs))\n",
        "        return inputs\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gjB17DfePyIm"
      },
      "source": [
        "这些损失（包括由任何内部层创建的损失）可通过 `layer.losses` 取回。此属性会在每个 `__call__()` 开始时重置到顶层，因此 `layer.losses` 始终包含在上一次前向传递过程中创建的损失值。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "48QiYwJRIe96"
      },
      "outputs": [],
      "source": [
        "class OuterLayer(keras.layers.Layer):\n",
        "    def __init__(self):\n",
        "        super(OuterLayer, self).__init__()\n",
        "        self.activity_reg = ActivityRegularizationLayer(1e-2)\n",
        "\n",
        "    def call(self, inputs):\n",
        "        return self.activity_reg(inputs)\n",
        "\n",
        "\n",
        "layer = OuterLayer()\n",
        "assert len(layer.losses) == 0  # No losses yet since the layer has never been called\n",
        "\n",
        "_ = layer(tf.zeros(1, 1))\n",
        "assert len(layer.losses) == 1  # We created one loss value\n",
        "\n",
        "# `layer.losses` gets reset at the start of each __call__\n",
        "_ = layer(tf.zeros(1, 1))\n",
        "assert len(layer.losses) == 1  # This is the loss created during the call above"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "72VW9A52XfTT"
      },
      "source": [
        "此外，`loss` 属性还包含为任何内部层的权重创建的正则化损失："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OBrP81sfVWUn"
      },
      "outputs": [],
      "source": [
        "class OuterLayerWithKernelRegularizer(keras.layers.Layer):\n",
        "    def __init__(self):\n",
        "        super(OuterLayerWithKernelRegularizer, self).__init__()\n",
        "        self.dense = keras.layers.Dense(\n",
        "            32, kernel_regularizer=tf.keras.regularizers.l2(1e-3)\n",
        "        )\n",
        "\n",
        "    def call(self, inputs):\n",
        "        return self.dense(inputs)\n",
        "\n",
        "\n",
        "layer = OuterLayerWithKernelRegularizer()\n",
        "_ = layer(tf.zeros((1, 1)))\n",
        "\n",
        "# This is `1e-3 * sum(layer.dense.kernel ** 2)`,\n",
        "# created by the `kernel_regularizer` above.\n",
        "print(layer.losses)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7P88QF0urVzA"
      },
      "source": [
        "在编写训练循环时应考虑这些损失，如下所示：\n",
        "\n",
        "```python\n",
        "# Instantiate an optimizer. optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3) loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)  # Iterate over the batches of a dataset. for x_batch_train, y_batch_train in train_dataset:   with tf.GradientTape() as tape:     logits = layer(x_batch_train)  # Logits for this minibatch     # Loss value for this minibatch     loss_value = loss_fn(y_batch_train, logits)     # Add extra losses created during this forward pass:     loss_value += sum(model.losses)    grads = tape.gradient(loss_value, model.trainable_weights)   optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9Wo9mCyvIheY"
      },
      "source": [
        "有关编写训练循环的详细指南，请参阅[从头开始编写训练循环](https://tensorflow.google.cn/guide/keras/writing_a_training_loop_from_scratch/)指南。\n",
        "\n",
        "这些损失还可以无缝使用 `fit()`（它们会自动求和并添加到主损失中，如果有的话）："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "prhlyHQU8eqT"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "inputs = keras.Input(shape=(3,))\n",
        "outputs = ActivityRegularizationLayer()(inputs)\n",
        "model = keras.Model(inputs, outputs)\n",
        "\n",
        "# If there is a loss passed in `compile`, thee regularization\n",
        "# losses get added to it\n",
        "model.compile(optimizer=\"adam\", loss=\"mse\")\n",
        "model.fit(np.random.random((2, 3)), np.random.random((2, 3)))\n",
        "\n",
        "# It's also possible not to pass any loss in `compile`,\n",
        "# since the model already has a loss to minimize, via the `add_loss`\n",
        "# call during the forward pass!\n",
        "model.compile(optimizer=\"adam\")\n",
        "model.fit(np.random.random((2, 3)), np.random.random((2, 3)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5Fw96x84S8HS"
      },
      "source": [
        "## `add_metric()` 方法\n",
        "\n",
        "与 `add_loss()` 类似，层还具有 `add_metric()` 方法，用于在训练过程中跟踪数量的移动平均值。\n",
        "\n",
        "请思考下面的 \"logistic endpoint\" 层。它将预测和目标作为输入，计算通过 `add_loss()` 跟踪的损失，并计算通过 `add_metric()` 跟踪的准确率标量。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kjhgiCfke2Dx"
      },
      "outputs": [],
      "source": [
        "class LogisticEndpoint(keras.layers.Layer):\n",
        "    def __init__(self, name=None):\n",
        "        super(LogisticEndpoint, self).__init__(name=name)\n",
        "        self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)\n",
        "        self.accuracy_fn = keras.metrics.BinaryAccuracy()\n",
        "\n",
        "    def call(self, targets, logits, sample_weights=None):\n",
        "        # Compute the training-time loss value and add it\n",
        "        # to the layer using `self.add_loss()`.\n",
        "        loss = self.loss_fn(targets, logits, sample_weights)\n",
        "        self.add_loss(loss)\n",
        "\n",
        "        # Log accuracy as a metric and add it\n",
        "        # to the layer using `self.add_metric()`.\n",
        "        acc = self.accuracy_fn(targets, logits, sample_weights)\n",
        "        self.add_metric(acc, name=\"accuracy\")\n",
        "\n",
        "        # Return the inference-time prediction tensor (for `.predict()`).\n",
        "        return tf.nn.softmax(logits)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eEcShLRgzCcp"
      },
      "source": [
        "可通过 `layer.metrics` 访问以这种方式跟踪的指标："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vK9Y6ierSyqd"
      },
      "outputs": [],
      "source": [
        "layer = LogisticEndpoint()\n",
        "\n",
        "targets = tf.ones((2, 2))\n",
        "logits = tf.ones((2, 2))\n",
        "y = layer(targets, logits)\n",
        "\n",
        "print(\"layer.metrics:\", layer.metrics)\n",
        "print(\"current accuracy value:\", float(layer.metrics[0].result()))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ogfTmUPmHqPS"
      },
      "source": [
        "和 `add_loss()` 一样，这些指标也是通过 `fit()` 跟踪的："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YEzldbjt7oF8"
      },
      "outputs": [],
      "source": [
        "inputs = keras.Input(shape=(3,), name=\"inputs\")\n",
        "targets = keras.Input(shape=(10,), name=\"targets\")\n",
        "logits = keras.layers.Dense(10)(inputs)\n",
        "predictions = LogisticEndpoint(name=\"predictions\")(logits, targets)\n",
        "\n",
        "model = keras.Model(inputs=[inputs, targets], outputs=predictions)\n",
        "model.compile(optimizer=\"adam\")\n",
        "\n",
        "data = {\n",
        "    \"inputs\": np.random.random((3, 3)),\n",
        "    \"targets\": np.random.random((3, 10)),\n",
        "}\n",
        "model.fit(data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jvZNZlgFFPGW"
      },
      "source": [
        "## 可选择在层上启用序列化\n",
        "\n",
        "如果需要将自定义层作为[函数式模型](https://tensorflow.google.cn/guide/keras/functional/)的一部分进行序列化，您可以选择实现 `get_config()` 方法："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3EWDINbqDVRZ"
      },
      "outputs": [],
      "source": [
        "class Linear(keras.layers.Layer):\n",
        "    def __init__(self, units=32):\n",
        "        super(Linear, self).__init__()\n",
        "        self.units = units\n",
        "\n",
        "    def build(self, input_shape):\n",
        "        self.w = self.add_weight(\n",
        "            shape=(input_shape[-1], self.units),\n",
        "            initializer=\"random_normal\",\n",
        "            trainable=True,\n",
        "        )\n",
        "        self.b = self.add_weight(\n",
        "            shape=(self.units,), initializer=\"random_normal\", trainable=True\n",
        "        )\n",
        "\n",
        "    def call(self, inputs):\n",
        "        return tf.matmul(inputs, self.w) + self.b\n",
        "\n",
        "    def get_config(self):\n",
        "        return {\"units\": self.units}\n",
        "\n",
        "\n",
        "# Now you can recreate the layer from its config:\n",
        "layer = Linear(64)\n",
        "config = layer.get_config()\n",
        "print(config)\n",
        "new_layer = Linear.from_config(config)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RBvNVeGPxhRE"
      },
      "source": [
        "请注意，基本 `Layer` 类的 `__init__()` 方法会接受一些关键字参数，尤其是 `name` 和 `dtype`。最好将这些参数传递给 `__init__()` 中的父类，并将其包含在层配置中："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fRdhwwqKTZxI"
      },
      "outputs": [],
      "source": [
        "class Linear(keras.layers.Layer):\n",
        "    def __init__(self, units=32, **kwargs):\n",
        "        super(Linear, self).__init__(**kwargs)\n",
        "        self.units = units\n",
        "\n",
        "    def build(self, input_shape):\n",
        "        self.w = self.add_weight(\n",
        "            shape=(input_shape[-1], self.units),\n",
        "            initializer=\"random_normal\",\n",
        "            trainable=True,\n",
        "        )\n",
        "        self.b = self.add_weight(\n",
        "            shape=(self.units,), initializer=\"random_normal\", trainable=True\n",
        "        )\n",
        "\n",
        "    def call(self, inputs):\n",
        "        return tf.matmul(inputs, self.w) + self.b\n",
        "\n",
        "    def get_config(self):\n",
        "        config = super(Linear, self).get_config()\n",
        "        config.update({\"units\": self.units})\n",
        "        return config\n",
        "\n",
        "\n",
        "layer = Linear(64)\n",
        "config = layer.get_config()\n",
        "print(config)\n",
        "new_layer = Linear.from_config(config)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cqHlieM5gqgE"
      },
      "source": [
        "如果根据层的配置对层进行反序列化时需要更大的灵活性，还可以重写 `from_config()` 类方法。下面是 `from_config()` 的基本实现：\n",
        "\n",
        "```python\n",
        "def from_config(cls, config):   return cls(**config)\n",
        "```\n",
        "\n",
        "要详细了解序列化和保存，请参阅完整的[保存和序列化模型](https://tensorflow.google.cn/guide/keras/save_and_serialize/)指南。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qCXh57DUwJd8"
      },
      "source": [
        "## `call()` 方法中的特权 `training` 参数\n",
        "\n",
        "某些层，尤其是 `BatchNormalization` 层和 `Dropout` 层，在训练和推断期间具有不同的行为。对于此类层，标准做法是在 `call()` 方法中公开 `training`（布尔）参数。\n",
        "\n",
        "通过在 `call()` 中公开此参数，可以启用内置的训练和评估循环（例如 `fit()`）以在训练和推断中正确使用层。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BP35dXiYoTcD"
      },
      "outputs": [],
      "source": [
        "class CustomDropout(keras.layers.Layer):\n",
        "    def __init__(self, rate, **kwargs):\n",
        "        super(CustomDropout, self).__init__(**kwargs)\n",
        "        self.rate = rate\n",
        "\n",
        "    def call(self, inputs, training=None):\n",
        "        if training:\n",
        "            return tf.nn.dropout(inputs, rate=self.rate)\n",
        "        return inputs\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OudWKhz1iQXP"
      },
      "source": [
        "## `call()` 方法中的特权 `mask` 参数\n",
        "\n",
        "`call()` 支持的另一个特权参数是 `mask` 参数。\n",
        "\n",
        "它会出现在所有 Keras RNN 层中。掩码是布尔张量（在输入中每个时间步骤对应一个布尔值），用于在处理时间序列数据时跳过某些输入时间步骤。\n",
        "\n",
        "当先前的层生成了掩码时，Keras 会自动将正确的 `mask` 参数传递给 `__call__()`（针对支持它的层）。掩码生成层是配置了 `mask_zero=True` 的 `Embedding` 层和 `Masking` 层。\n",
        "\n",
        "要了解有关遮盖以及如何编写启用了遮盖的图层的详细信息，请查看[“了解填充和遮盖”](https://tensorflow.google.cn/guide/keras/masking_and_padding/)指南。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6n3ZYo8abHGd"
      },
      "source": [
        "## `Model` 类\n",
        "\n",
        "通常，您会使用 `Layer` 类来定义内部计算块，并使用 `Model` 类来定义外部模型，即您将训练的对象。\n",
        "\n",
        "例如，在 ResNet50 模型中，您会有几个子类化 `Layer` 的 ResNet 块，和一个包含了整个 ResNet50 网络的单个 `Model`。\n",
        "\n",
        "`Model` 类具有与 `Layer` 相同的 API，但有如下区别：\n",
        "\n",
        "- 它会公开内置训练、评估和预测循环（`model.fit()`、`model.evaluate()`、`model.predict()`）。\n",
        "- 它会通过 `model.layers` 属性公开其内部层的列表。\n",
        "- 它会公开保存和序列化 API（`save()`、`save_weights()`…）\n",
        "\n",
        "实际上，`Layer` 类对应于我们在文献中所称的“层”（如“卷积层”或“循环层”）或“块”（如“ResNet 块”或“Inception 块”）。\n",
        "\n",
        "同时，`Model` 类对应于文献中所称的“模型”（如“深度学习模型”）或“网络”（如“深度神经网络”）。\n",
        "\n",
        "因此，如果您想知道“我应该用 `Layer` 类还是 `Model` 类？”，请问自己：我是否需要在它上面调用 `fit()`？我是否需要在它上面调用 `save()`？如果是，则使用 `Model`。如果不是（要么因为您的类只是更大系统中的一个块，要么因为您正在自己编写训练和保存代码），则使用 `Layer`。\n",
        "\n",
        "例如，我们可以拿上面的 mini-resnet 示例为例，用它来构建一个 `Model`，该模型可以通过 `fit()` 进行训练，并可以通过 `save_weights()` 来保存："
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FNGSP4aBm6fv"
      },
      "source": [
        "```python\n",
        "class ResNet(tf.keras.Model):      def __init__(self):         super(ResNet, self).__init__()         self.block_1 = ResNetBlock()         self.block_2 = ResNetBlock()         self.global_pool = layers.GlobalAveragePooling2D()         self.classifier = Dense(num_classes)      def call(self, inputs):         x = self.block_1(inputs)         x = self.block_2(x)         x = self.global_pool(x)         return self.classifier(x)   resnet = ResNet() dataset = ... resnet.fit(dataset, epochs=10) resnet.save(filepath)\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Cdhz04jvl5pu"
      },
      "source": [
        "## 汇总：端到端示例\n",
        "\n",
        "到目前为止，您已学习以下内容：\n",
        "\n",
        "- `Layer` 封装了（在 `__init__()` 或 `build()` 中创建的）状态和（在 `call()` 中定义的）部分计算。\n",
        "- 层可以递归嵌套以创建新的更大的计算块。\n",
        "- 层可以通过 `add_loss()` 和 `add_metric()` 创建并跟踪损失（通常是正则化损失）以及指标。\n",
        "- 您要训练的外部容器是 `Model`。`Model` 就像 `Layer`，但是添加了训练和序列化实用工具。\n",
        "\n",
        "让我们将这些内容全部汇总到一个完整示例：我们将实现一个变分自动编码器 (VAE)，并用 MNIST 数字对其进行训练。\n",
        "\n",
        "我们的 VAE 将是 `Model` 的一个子类，它是作为子类化 `Layer` 的嵌套组合层进行构建的。它将具有正则化损失（KL 散度）。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MkWGvP01ZsIu"
      },
      "outputs": [],
      "source": [
        "from tensorflow.keras import layers\n",
        "\n",
        "\n",
        "class Sampling(layers.Layer):\n",
        "    \"\"\"Uses (z_mean, z_log_var) to sample z, the vector encoding a digit.\"\"\"\n",
        "\n",
        "    def call(self, inputs):\n",
        "        z_mean, z_log_var = inputs\n",
        "        batch = tf.shape(z_mean)[0]\n",
        "        dim = tf.shape(z_mean)[1]\n",
        "        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))\n",
        "        return z_mean + tf.exp(0.5 * z_log_var) * epsilon\n",
        "\n",
        "\n",
        "class Encoder(layers.Layer):\n",
        "    \"\"\"Maps MNIST digits to a triplet (z_mean, z_log_var, z).\"\"\"\n",
        "\n",
        "    def __init__(self, latent_dim=32, intermediate_dim=64, name=\"encoder\", **kwargs):\n",
        "        super(Encoder, self).__init__(name=name, **kwargs)\n",
        "        self.dense_proj = layers.Dense(intermediate_dim, activation=\"relu\")\n",
        "        self.dense_mean = layers.Dense(latent_dim)\n",
        "        self.dense_log_var = layers.Dense(latent_dim)\n",
        "        self.sampling = Sampling()\n",
        "\n",
        "    def call(self, inputs):\n",
        "        x = self.dense_proj(inputs)\n",
        "        z_mean = self.dense_mean(x)\n",
        "        z_log_var = self.dense_log_var(x)\n",
        "        z = self.sampling((z_mean, z_log_var))\n",
        "        return z_mean, z_log_var, z\n",
        "\n",
        "\n",
        "class Decoder(layers.Layer):\n",
        "    \"\"\"Converts z, the encoded digit vector, back into a readable digit.\"\"\"\n",
        "\n",
        "    def __init__(self, original_dim, intermediate_dim=64, name=\"decoder\", **kwargs):\n",
        "        super(Decoder, self).__init__(name=name, **kwargs)\n",
        "        self.dense_proj = layers.Dense(intermediate_dim, activation=\"relu\")\n",
        "        self.dense_output = layers.Dense(original_dim, activation=\"sigmoid\")\n",
        "\n",
        "    def call(self, inputs):\n",
        "        x = self.dense_proj(inputs)\n",
        "        return self.dense_output(x)\n",
        "\n",
        "\n",
        "class VariationalAutoEncoder(keras.Model):\n",
        "    \"\"\"Combines the encoder and decoder into an end-to-end model for training.\"\"\"\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        original_dim,\n",
        "        intermediate_dim=64,\n",
        "        latent_dim=32,\n",
        "        name=\"autoencoder\",\n",
        "        **kwargs\n",
        "    ):\n",
        "        super(VariationalAutoEncoder, self).__init__(name=name, **kwargs)\n",
        "        self.original_dim = original_dim\n",
        "        self.encoder = Encoder(latent_dim=latent_dim, intermediate_dim=intermediate_dim)\n",
        "        self.decoder = Decoder(original_dim, intermediate_dim=intermediate_dim)\n",
        "\n",
        "    def call(self, inputs):\n",
        "        z_mean, z_log_var, z = self.encoder(inputs)\n",
        "        reconstructed = self.decoder(z)\n",
        "        # Add KL divergence regularization loss.\n",
        "        kl_loss = -0.5 * tf.reduce_mean(\n",
        "            z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1\n",
        "        )\n",
        "        self.add_loss(kl_loss)\n",
        "        return reconstructed\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uBAILhN56qaB"
      },
      "source": [
        "让我们在 MNIST 上编写一个简单的训练循环："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3WDjPVGOvvMR"
      },
      "outputs": [],
      "source": [
        "original_dim = 784\n",
        "vae = VariationalAutoEncoder(original_dim, 64, 32)\n",
        "\n",
        "optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
        "mse_loss_fn = tf.keras.losses.MeanSquaredError()\n",
        "\n",
        "loss_metric = tf.keras.metrics.Mean()\n",
        "\n",
        "(x_train, _), _ = tf.keras.datasets.mnist.load_data()\n",
        "x_train = x_train.reshape(60000, 784).astype(\"float32\") / 255\n",
        "\n",
        "train_dataset = tf.data.Dataset.from_tensor_slices(x_train)\n",
        "train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)\n",
        "\n",
        "epochs = 2\n",
        "\n",
        "# Iterate over epochs.\n",
        "for epoch in range(epochs):\n",
        "    print(\"Start of epoch %d\" % (epoch,))\n",
        "\n",
        "    # Iterate over the batches of the dataset.\n",
        "    for step, x_batch_train in enumerate(train_dataset):\n",
        "        with tf.GradientTape() as tape:\n",
        "            reconstructed = vae(x_batch_train)\n",
        "            # Compute reconstruction loss\n",
        "            loss = mse_loss_fn(x_batch_train, reconstructed)\n",
        "            loss += sum(vae.losses)  # Add KLD regularization loss\n",
        "\n",
        "        grads = tape.gradient(loss, vae.trainable_weights)\n",
        "        optimizer.apply_gradients(zip(grads, vae.trainable_weights))\n",
        "\n",
        "        loss_metric(loss)\n",
        "\n",
        "        if step % 100 == 0:\n",
        "            print(\"step %d: mean loss = %.4f\" % (step, loss_metric.result()))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yUWIKs8Nx70G"
      },
      "source": [
        "请注意，由于 VAE 是 `Model` 子类化的结果，它具有内置的训练循环。因此，您也可以用以下方式训练它："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cuadnNIqW0Jb"
      },
      "outputs": [],
      "source": [
        "vae = VariationalAutoEncoder(784, 64, 32)\n",
        "\n",
        "optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
        "\n",
        "vae.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())\n",
        "vae.fit(x_train, x_train, epochs=2, batch_size=64)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RiujG1u0qLp0"
      },
      "source": [
        "## 超越面向对象的开发：函数式 API\n",
        "\n",
        "这个示例对您来说是否包含了太多面向对象的开发？您也可以使用[函数式 API](https://tensorflow.google.cn/guide/keras/functional/) 来构建模型。重要的是，选择其中一种样式并不妨碍您利用以另一种样式编写的组件：您可以始终混合搭配。\n",
        "\n",
        "例如，下面的函数式 API 示例重用了我们在上面示例中定义的同一个 `Sampling` 层："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "el7h9ajsqCR9"
      },
      "outputs": [],
      "source": [
        "original_dim = 784\n",
        "intermediate_dim = 64\n",
        "latent_dim = 32\n",
        "\n",
        "# Define encoder model.\n",
        "original_inputs = tf.keras.Input(shape=(original_dim,), name=\"encoder_input\")\n",
        "x = layers.Dense(intermediate_dim, activation=\"relu\")(original_inputs)\n",
        "z_mean = layers.Dense(latent_dim, name=\"z_mean\")(x)\n",
        "z_log_var = layers.Dense(latent_dim, name=\"z_log_var\")(x)\n",
        "z = Sampling()((z_mean, z_log_var))\n",
        "encoder = tf.keras.Model(inputs=original_inputs, outputs=z, name=\"encoder\")\n",
        "\n",
        "# Define decoder model.\n",
        "latent_inputs = tf.keras.Input(shape=(latent_dim,), name=\"z_sampling\")\n",
        "x = layers.Dense(intermediate_dim, activation=\"relu\")(latent_inputs)\n",
        "outputs = layers.Dense(original_dim, activation=\"sigmoid\")(x)\n",
        "decoder = tf.keras.Model(inputs=latent_inputs, outputs=outputs, name=\"decoder\")\n",
        "\n",
        "# Define VAE model.\n",
        "outputs = decoder(z)\n",
        "vae = tf.keras.Model(inputs=original_inputs, outputs=outputs, name=\"vae\")\n",
        "\n",
        "# Add KL divergence regularization loss.\n",
        "kl_loss = -0.5 * tf.reduce_mean(z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1)\n",
        "vae.add_loss(kl_loss)\n",
        "\n",
        "# Train.\n",
        "optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
        "vae.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())\n",
        "vae.fit(x_train, x_train, epochs=3, batch_size=64)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VsMiP9ABJHWF"
      },
      "source": [
        "有关详细信息，请务必阅读[函数式 API](https://tensorflow.google.cn/guide/keras/functional/) 指南。"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "custom_layers_and_models.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
